{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6f01ea8a",
   "metadata": {
    "id": "6f01ea8a"
   },
   "source": [
    "# Torch-Rechub Tutorial: DeepFM\n",
    "\n",
    "- 场景：精排（CTR预测） \n",
    "- 模型：DeepFM\n",
    "- 数据：Criteo广告数据集\n",
    "\n",
    "- 学习目标\n",
    "  - 学会使用torch-rechub调用DeepFM进行CTR预测\n",
    "  - 学会基于torch-rechub的基础模块，使用pytorch复现DeepFM模型\n",
    "  \n",
    "- 学习材料：\n",
    "  - 模型思想介绍：https://datawhalechina.github.io/fun-rec/#/ch02/ch2.2/ch2.2.3/DeepFM\n",
    "  - rechub模型代码：https://github.com/datawhalechina/torch-rechub/blob/main/torch_rechub/models/ranking/deepfm.py\n",
    "  - 数据集详细描述：https://github.com/datawhalechina/torch-rechub/tree/main/examples/ranking\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "DGRGIGRJc9Gr",
   "metadata": {
    "id": "DGRGIGRJc9Gr"
   },
   "outputs": [],
   "source": [
    "#安装torch-rechub\n",
    "#!pip install torch-rechub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d63b81c",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "6d63b81c",
    "outputId": "11701743-1fa6-4b60-996a-50e371cec173"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch_rechub.models.ranking import WideDeep, DeepFM, DCN\n",
    "from torch_rechub.trainers import CTRTrainer\n",
    "from torch_rechub.basic.features import DenseFeature, SparseFeature\n",
    "from torch_rechub.utils.data import DataGenerator\n",
    "from tqdm import tqdm\n",
    "from sklearn.preprocessing import MinMaxScaler, LabelEncoder\n",
    "torch.manual_seed(2022) #固定随机种子"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44a036c1",
   "metadata": {
    "id": "44a036c1"
   },
   "source": [
    "### 数据集介绍\n",
    "该数据集是Criteo Labs发布的在线广告数据集。 它包含数百万个展示广告的点击反馈记录，该数据可作为点击率(CTR)预测的基准。 数据集具有40个特征，第一列是标签，其中值1表示已点击广告，而值0表示未点击广告。 其他特征包含13个dense特征和26个sparse特征。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f2fad7ad",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 300
    },
    "id": "f2fad7ad",
    "outputId": "fb2428b9-6377-4845-f5a2-77a5d76d2dcc"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>I1</th>\n",
       "      <th>I2</th>\n",
       "      <th>I3</th>\n",
       "      <th>I4</th>\n",
       "      <th>I5</th>\n",
       "      <th>I6</th>\n",
       "      <th>I7</th>\n",
       "      <th>I8</th>\n",
       "      <th>I9</th>\n",
       "      <th>...</th>\n",
       "      <th>C17</th>\n",
       "      <th>C18</th>\n",
       "      <th>C19</th>\n",
       "      <th>C20</th>\n",
       "      <th>C21</th>\n",
       "      <th>C22</th>\n",
       "      <th>C23</th>\n",
       "      <th>C24</th>\n",
       "      <th>C25</th>\n",
       "      <th>C26</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "      <td>104.0</td>\n",
       "      <td>27.0</td>\n",
       "      <td>1990.0</td>\n",
       "      <td>142.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>32.0</td>\n",
       "      <td>37.0</td>\n",
       "      <td>...</td>\n",
       "      <td>e5ba7672</td>\n",
       "      <td>25c88e42</td>\n",
       "      <td>21ddcdc9</td>\n",
       "      <td>b1252a9d</td>\n",
       "      <td>0e8585d2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>32c7478e</td>\n",
       "      <td>0d4a6d1a</td>\n",
       "      <td>001f3601</td>\n",
       "      <td>92c878de</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-1</td>\n",
       "      <td>63.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>1470.0</td>\n",
       "      <td>61.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>37.0</td>\n",
       "      <td>46.0</td>\n",
       "      <td>...</td>\n",
       "      <td>e5ba7672</td>\n",
       "      <td>d3303ea5</td>\n",
       "      <td>21ddcdc9</td>\n",
       "      <td>b1252a9d</td>\n",
       "      <td>7633c7c8</td>\n",
       "      <td>NaN</td>\n",
       "      <td>32c7478e</td>\n",
       "      <td>17f458f7</td>\n",
       "      <td>001f3601</td>\n",
       "      <td>71236095</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>370</td>\n",
       "      <td>4.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1787.0</td>\n",
       "      <td>65.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>489.0</td>\n",
       "      <td>...</td>\n",
       "      <td>3486227d</td>\n",
       "      <td>642f2610</td>\n",
       "      <td>55dd3565</td>\n",
       "      <td>b1252a9d</td>\n",
       "      <td>5c8dc711</td>\n",
       "      <td>NaN</td>\n",
       "      <td>423fab69</td>\n",
       "      <td>45ab94c8</td>\n",
       "      <td>2bf691b1</td>\n",
       "      <td>c84c4aec</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>19.0</td>\n",
       "      <td>10</td>\n",
       "      <td>30.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>33.0</td>\n",
       "      <td>47.0</td>\n",
       "      <td>126.0</td>\n",
       "      <td>...</td>\n",
       "      <td>e5ba7672</td>\n",
       "      <td>a78bd508</td>\n",
       "      <td>21ddcdc9</td>\n",
       "      <td>5840adea</td>\n",
       "      <td>c2a93b37</td>\n",
       "      <td>NaN</td>\n",
       "      <td>32c7478e</td>\n",
       "      <td>1793a828</td>\n",
       "      <td>e8b83407</td>\n",
       "      <td>2fede552</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "      <td>36.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>4684.0</td>\n",
       "      <td>217.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>35.0</td>\n",
       "      <td>135.0</td>\n",
       "      <td>...</td>\n",
       "      <td>e5ba7672</td>\n",
       "      <td>7ce63c71</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>af5dc647</td>\n",
       "      <td>NaN</td>\n",
       "      <td>dbb486d7</td>\n",
       "      <td>1793a828</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 40 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   label    I1   I2     I3    I4      I5     I6    I7    I8     I9  ...  \\\n",
       "0      0   0.0    0  104.0  27.0  1990.0  142.0   4.0  32.0   37.0  ...   \n",
       "1      0   0.0   -1   63.0  40.0  1470.0   61.0   4.0  37.0   46.0  ...   \n",
       "2      0   0.0  370    4.0   1.0  1787.0   65.0  14.0  25.0  489.0  ...   \n",
       "3      1  19.0   10   30.0  10.0     1.0    3.0  33.0  47.0  126.0  ...   \n",
       "4      0   0.0    0   36.0  22.0  4684.0  217.0   9.0  35.0  135.0  ...   \n",
       "\n",
       "        C17       C18       C19       C20       C21  C22       C23       C24  \\\n",
       "0  e5ba7672  25c88e42  21ddcdc9  b1252a9d  0e8585d2  NaN  32c7478e  0d4a6d1a   \n",
       "1  e5ba7672  d3303ea5  21ddcdc9  b1252a9d  7633c7c8  NaN  32c7478e  17f458f7   \n",
       "2  3486227d  642f2610  55dd3565  b1252a9d  5c8dc711  NaN  423fab69  45ab94c8   \n",
       "3  e5ba7672  a78bd508  21ddcdc9  5840adea  c2a93b37  NaN  32c7478e  1793a828   \n",
       "4  e5ba7672  7ce63c71       NaN       NaN  af5dc647  NaN  dbb486d7  1793a828   \n",
       "\n",
       "        C25       C26  \n",
       "0  001f3601  92c878de  \n",
       "1  001f3601  71236095  \n",
       "2  2bf691b1  c84c4aec  \n",
       "3  e8b83407  2fede552  \n",
       "4       NaN       NaN  \n",
       "\n",
       "[5 rows x 40 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_path = '../examples/ranking/data/criteo/criteo_sample.csv'\n",
    "data = pd.read_csv(data_path)  \n",
    "#data = pd.read_csv(data_path, compression=\"gzip\") #if the raw_data is .gz file\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "HbtXRmEHKup0",
   "metadata": {
    "id": "HbtXRmEHKup0"
   },
   "source": [
    "### 特征工程\n",
    "- Dense特征：又称数值型特征，例如薪资、年龄。 本教程中对Dense特征进行两种操作：\n",
    "  - MinMaxScaler归一化，使其取值在[0,1]之间\n",
    "  - 将其离散化成新的Sparse特征\n",
    "- Sparse特征：又称类别型特征，例如性别、学历。本教程中对Sparse特征直接进行LabelEncoder编码操作，将原始的类别字符串映射为数值，在模型中将为每一种取值生成Embedding向量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ea5186c5",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ea5186c5",
    "outputId": "fc99af3e-3207-453d-eae1-22b7222e9152"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 855.60it/s]\n",
      "100%|██████████| 39/39 [00:00<00:00, 2674.02it/s]\n"
     ]
    }
   ],
   "source": [
    "dense_cols= [f for f in data.columns.tolist() if f[0] == \"I\"] #以I开头的特征名为dense特征\n",
    "sparse_cols = [f for f in data.columns.tolist() if f[0] == \"C\"]  #以C开头的特征名为sparse特征\n",
    "\n",
    "data[dense_cols] = data[dense_cols].fillna(0) #填充空缺值\n",
    "data[sparse_cols] = data[sparse_cols].fillna('-996')\n",
    "\n",
    "\n",
    "#criteo比赛冠军分享的一种离散化思路，不用纠结其原理，大家也可以试试别的离散化手段\n",
    "def convert_numeric_feature(val):\n",
    "    v = int(val)\n",
    "    if v > 2:\n",
    "        return int(np.log(v)**2)\n",
    "    else:\n",
    "        return v - 2\n",
    "        \n",
    "for col in tqdm(dense_cols):  #将离散化dense特征列设置为新的sparse特征列\n",
    "    sparse_cols.append(col + \"_sparse\")\n",
    "    data[col + \"_sparse\"] = data[col].apply(lambda x: convert_numeric_feature(x))\n",
    "\n",
    "scaler = MinMaxScaler()  #对dense特征列归一化\n",
    "data[dense_cols] = scaler.fit_transform(data[dense_cols])\n",
    "\n",
    "for col in tqdm(sparse_cols):  #sparse特征编码\n",
    "    lbe = LabelEncoder()\n",
    "    data[col] = lbe.fit_transform(data[col])\n",
    "\n",
    "#重点：将每个特征定义为torch-rechub所支持的特征基类，dense特征只需指定特征名，sparse特征需指定特征名、特征取值个数(vocab_size)、embedding维度(embed_dim)\n",
    "dense_features = [DenseFeature(feature_name) for feature_name in dense_cols]\n",
    "sparse_features = [SparseFeature(feature_name, vocab_size=data[feature_name].nunique(), embed_dim=16) for feature_name in sparse_cols]\n",
    "y = data[\"label\"]\n",
    "del data[\"label\"]\n",
    "x = data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c5490237",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "c5490237",
    "outputId": "2dcdac2e-fcea-4a6a-9baa-879610ce2fdc"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the samples of train : val : test are  80 : 11 : 24\n"
     ]
    }
   ],
   "source": [
    "# 构建模型输入所需要的dataloader，区分验证集、测试集，指定batch大小\n",
    "#split_ratio=[0.7,0.1] 指的是训练集占比70%，验证集占比10%，剩下的全部为测试集\n",
    "dg = DataGenerator(x, y) \n",
    "train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=256, num_workers=8)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81f1290e",
   "metadata": {
    "id": "81f1290e"
   },
   "source": [
    "### 训练模型\n",
    "\n",
    "训练一个DeepFM模型，只需要指定DeepFM的模型结构参数，学习率等训练参数。\n",
    "对于DeepFM而言，主要参数如下：\n",
    "\n",
    "- deep_features指用deep模块训练的特征（兼容dense和sparse），\n",
    "- fm_features指用fm模块训练的特征，只能传入sparse类型\n",
    "- mlp_params指定deep模块中，MLP层的参数\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "08997423",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "08997423",
    "outputId": "45dbee47-279a-42ee-aeae-24d11429856c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train: 100%|██████████| 1/1 [00:08<00:00,  8.24s/it]\n",
      "validation: 100%|██████████| 1/1 [00:13<00:00, 13.86s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0 validation: auc: 0.3333333333333333\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "validation: 100%|██████████| 1/1 [00:07<00:00,  7.94s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test auc: 0.768421052631579\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from torch_rechub.models.ranking import DeepFM\n",
    "from torch_rechub.trainers import CTRTrainer\n",
    "\n",
    "#定义模型\n",
    "model = DeepFM(\n",
    "        deep_features=dense_features+sparse_features,\n",
    "        fm_features=sparse_features,\n",
    "        mlp_params={\"dims\": [256, 128], \"dropout\": 0.2, \"activation\": \"relu\"},\n",
    "    )\n",
    "\n",
    "# 模型训练，需要学习率、设备等一般的参数，此外我们还支持earlystoping策略，及时发现过拟合\n",
    "ctr_trainer = CTRTrainer(\n",
    "    model,\n",
    "    optimizer_params={\"lr\": 1e-4, \"weight_decay\": 1e-5},\n",
    "    n_epoch=1,\n",
    "    earlystop_patience=3,\n",
    "    device='cpu', #如果有gpu，可设置成cuda:0\n",
    "    model_path='./', #模型存储路径\n",
    ")\n",
    "ctr_trainer.fit(train_dataloader, val_dataloader)\n",
    "\n",
    "# 查看在测试集上的性能\n",
    "auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)\n",
    "print(f'test auc: {auc}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "xdOaLztKy2Ol",
   "metadata": {
    "id": "xdOaLztKy2Ol"
   },
   "source": [
    "### 使用其他的排序模型训练Criteo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "U-qhhoaqzQkE",
   "metadata": {
    "id": "U-qhhoaqzQkE"
   },
   "outputs": [],
   "source": [
    "#定义相应的模型，用同样的方式训练\n",
    "model = WideDeep(wide_features=dense_features, deep_features=sparse_features, mlp_params={\"dims\": [256, 128], \"dropout\": 0.2, \"activation\": \"relu\"})\n",
    "\n",
    "model = DCN(features=dense_features + sparse_features, n_cross_layers=3, mlp_params={\"dims\": [256, 128]})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "XOloY8LzRGqA",
   "metadata": {
    "id": "XOloY8LzRGqA"
   },
   "source": [
    "### 从调包到自定义自己的模型\n",
    "恭喜朋友成功运行了DeepFM模型，并得到了CTR推荐的结果。\n",
    "接下来我们考虑如何实现自己的DeepFM模型。\n",
    "由于FM，MLP，LR，Embedding等基础模块被许多推荐模型共用，因此torch_rechub也帮我们集成好了这些小模块。我们在basic.layers中import即可。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "bmE9ax7xTqzL",
   "metadata": {
    "id": "bmE9ax7xTqzL"
   },
   "outputs": [],
   "source": [
    "from torch_rechub.basic.layers import FM, MLP, LR, EmbeddingLayer\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "t6m_TrdUTuxo",
   "metadata": {
    "id": "t6m_TrdUTuxo"
   },
   "source": [
    "有了基础的模块之后，搭建自己的模型就会很方便了，torch-rechub是基于pytorch的因此我们可以像传统的torch模型一样，定义一个model类，然后写好初始化和farward函数即可。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "72qpQPVbS35e",
   "metadata": {
    "id": "72qpQPVbS35e"
   },
   "outputs": [],
   "source": [
    "class MyDeepFM(torch.nn.Module):\n",
    "  # Deep和FM为两部分，分别处理不同的特征，因此传入的参数要有两种特征，由此我们得到参数deep_features,fm_features\n",
    "  # 此外神经网络类的模型中，基本组成原件为MLP多层感知机，多层感知机的参数也需要传进来，即为mlp_params\n",
    "  def __init__(self, deep_features, fm_features, mlp_params):\n",
    "    super().__init__()\n",
    "    self.deep_features = deep_features\n",
    "    self.fm_features = fm_features\n",
    "    self.deep_dims = sum([fea.embed_dim for fea in deep_features])\n",
    "    self.fm_dims = sum([fea.embed_dim for fea in fm_features])\n",
    "    # LR建模一阶特征交互\n",
    "    self.linear = LR(self.fm_dims)\n",
    "    # FM建模二阶特征交互\n",
    "    self.fm = FM(reduce_sum=True)\n",
    "    # 对特征做嵌入表征\n",
    "    self.embedding = EmbeddingLayer(deep_features + fm_features)\n",
    "    self.mlp = MLP(self.deep_dims, **mlp_params)\n",
    "\n",
    "  def forward(self, x):\n",
    "    input_deep = self.embedding(x, self.deep_features, squeeze_dim=True)  #[batch_size, deep_dims]\n",
    "    input_fm = self.embedding(x, self.fm_features, squeeze_dim=False)  #[batch_size, num_fields, embed_dim]\n",
    "\n",
    "    y_linear = self.linear(input_fm.flatten(start_dim=1))\n",
    "    y_fm = self.fm(input_fm)\n",
    "    y_deep = self.mlp(input_deep)  #[batch_size, 1]\n",
    "    # 最终的预测值为一阶特征交互，二阶特征交互，以及深层模型的组合\n",
    "    y = y_linear + y_fm + y_deep\n",
    "    # 利用sigmoid来将预测得分规整到0,1区间内\n",
    "    return torch.sigmoid(y.squeeze(1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5Wt_a3mYOoh",
   "metadata": {
    "id": "b5Wt_a3mYOoh"
   },
   "source": [
    "同样的，可以使用torch-rechub提供的trainer进行模型训练和模型评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "I4NOXPxhYT2E",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "I4NOXPxhYT2E",
    "outputId": "5949968d-54ad-46de-be6d-b2c80b6d9746"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train: 100%|██████████| 1/1 [00:07<00:00,  7.67s/it]\n",
      "validation: 100%|██████████| 1/1 [00:10<00:00, 10.79s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0 validation: auc: 0.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "validation: 100%|██████████| 1/1 [00:14<00:00, 14.61s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test auc: 0.25263157894736843\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model = MyDeepFM(\n",
    "        deep_features=dense_features+sparse_features,\n",
    "        fm_features=sparse_features,\n",
    "        mlp_params={\"dims\": [256, 128], \"dropout\": 0.2, \"activation\": \"relu\"},\n",
    "    )\n",
    "# 模型训练，需要学习率、设备等一般的参数，此外我们还支持earlystoping策略，及时发现过拟合\n",
    "ctr_trainer = CTRTrainer(\n",
    "    model,\n",
    "    optimizer_params={\"lr\": 1e-4, \"weight_decay\": 1e-5},\n",
    "    n_epoch=1,\n",
    "    earlystop_patience=3,\n",
    "    device='cpu',\n",
    "    model_path='./',\n",
    ")\n",
    "ctr_trainer.fit(train_dataloader, val_dataloader)\n",
    "\n",
    "# 查看在测试集上的性能\n",
    "auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)\n",
    "print(f'test auc: {auc}')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": " DeepFM.ipynb",
   "provenance": []
  },
  "interpreter": {
   "hash": "2f0699014af7f4c9080a159fe6ab9f0087a283cb8192b31d41a414a088fd29ff"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
